CapsGAN

A genreative Adversarial Network using a Capsule Network (CapsNet) as a discriminator

References

Dynamic Routing Between Capsules, Sara Sabour et al. https://arxiv.org/pdf/1710.09829.pdf

Capsule Networks (CapsNets), Aurélien Geron https://github.com/ageron/handson-ml/blob/master/extra_capsnets.ipynb

CapsuleNet on MNIST, Kevin Mader https://www.kaggle.com/kmader/capsulenet-on-mnist

capsule-GAN, Guseyn Gadirov https://github.com/gusgad/capsule-GAN/blob/master/capsule_gan.ipynb

In [1]:
import numpy as np
import os
import pandas as pd
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import Graph, Session
import keras.backend as K
from keras import backend as K
from keras import Sequential, models, layers

import keras
from keras.models import Model,Sequential
from keras.datasets import mnist
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import adam
from keras.layers import Dense, Activation, Dropout, Convolution2D, Flatten, MaxPooling2D, Reshape, InputLayer, Input
from keras.utils import to_categorical
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from keras import initializers
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Lambda, Concatenate, Multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.optimizers import Adam


%matplotlib inline
Using TensorFlow backend.

Squash helper function

In [0]:
def squash(vector, axis=-1, epsilon=1e-7):
    
    s_squared_norm = K.sum(K.square(vector), axis, keepdims=True)
    safe_norm =  K.sqrt(s_squared_norm + epsilon)
    squash_factor = s_squared_norm / (1.0 + s_squared_norm)
    unit_vector = vector / safe_norm
    
    return squash_factor * unit_vector

Primary Capsules Layer

In [0]:
def primaryCaps(inputs, dim_vector, n_channels, kernel_size, strides, padding):
    '''
    inputs : output of first convolutional layer of shape (?x20x20x256)
    dim_vector : the dimension of the output vector of each capsule = 8
    n_channels : the number of different capsules = 32 
    
    output : output of the convultional layer reshaped to ?x1152x8 and squash function applied to it
    
    '''
    
    conv2 = Conv2D(filters=dim_vector*n_channels, kernel_size=kernel_size, strides=strides, padding=padding)(inputs)
    out = layers.Reshape(target_shape=[-1, dim_vector])(conv2)
    
    return layers.Lambda(squash)(out)

Building the CapsNet discriminator

This implementation uses dense layers as a replacement for the conventional Digit Capsules layer

In [0]:
def build_capsNet():
    
    
    #input shape of the mnist image
    input_shape=(28,28,1)
    
    #this capsnet only has 1 class as it only needs to distinguish if the image is real or fake
    num_classes=1
    
    
    #the image input to the model
    x_in = layers.Input(shape=input_shape)
    
    
    #first concolutional layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', name='conv1')(x_in)
    conv1 = LeakyReLU(alpha=0.1)(conv1)
    conv1 = BatchNormalization(momentum=0.8)(conv1)

    
    # primary capsules layer
    prim = primaryCaps(conv1, dim_vector=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
    prim = BatchNormalization(momentum=0.8)(prim)

    
    # Digit capsules layer implemented as normal dense layers
    x = Flatten()(prim)
    
    uhat = Dense(160, kernel_initializer='he_normal', bias_initializer='zeros', name='uhat_digitcaps')(x)
    
    c = Activation('softmax', name='softmax_digitcaps1')(uhat)
    c = Dense(160)(c)
    x = Multiply()([uhat, c])
    s_j = LeakyReLU()(x)
    
    
    c = Activation('softmax', name='softmax_digitcaps2')(s_j) 
    c = Dense(160)(c) 
    x = Multiply()([uhat, c])
    s_j = LeakyReLU()(x)

    c = Activation('softmax', name='softmax_digitcaps3')(s_j) 
    c = Dense(160)(c) 
    x = Multiply()([uhat, c])
    s_j = LeakyReLU()(x)
    
    
    pred = Dense(1, activation='sigmoid')(s_j)

    
    model = models.Model(x_in , pred)
    #model.compile(loss=margin_loss, optimizer=Adam(0.0002, 0.5),metrics={'pred': 'accuracy'})
    model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
    
    return model

Defining the margin loss

used for calculating the loss of the CapsNet

In [0]:
def margin_loss(y_true, y_pred):
    
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))
In [6]:
# Define the model
discriminator = build_capsNet()
discriminator.summary()
WARNING: Logging before flag parsing goes to stderr.
W0805 17:03:46.793862 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0805 17:03:46.818551 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0805 17:03:46.823836 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0805 17:03:46.859152 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0805 17:03:46.863372 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

W0805 17:03:47.835672 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.

W0805 17:03:48.322421 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0805 17:03:48.339775 140105437140864 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 20, 20, 256)  20992       input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 20, 20, 256)  0           conv1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 20, 20, 256)  1024        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 6, 6, 256)    5308672     batch_normalization_1[0][0]      
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 1152, 8)      0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1152, 8)      0           reshape_1[0][0]                  
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 1152, 8)      32          lambda_1[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 9216)         0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
uhat_digitcaps (Dense)          (None, 160)          1474720     flatten_1[0][0]                  
__________________________________________________________________________________________________
softmax_digitcaps1 (Activation) (None, 160)          0           uhat_digitcaps[0][0]             
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 160)          25760       softmax_digitcaps1[0][0]         
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 160)          0           uhat_digitcaps[0][0]             
                                                                 dense_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 160)          0           multiply_1[0][0]                 
__________________________________________________________________________________________________
softmax_digitcaps2 (Activation) (None, 160)          0           leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 160)          25760       softmax_digitcaps2[0][0]         
__________________________________________________________________________________________________
multiply_2 (Multiply)           (None, 160)          0           uhat_digitcaps[0][0]             
                                                                 dense_2[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 160)          0           multiply_2[0][0]                 
__________________________________________________________________________________________________
softmax_digitcaps3 (Activation) (None, 160)          0           leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 160)          25760       softmax_digitcaps3[0][0]         
__________________________________________________________________________________________________
multiply_3 (Multiply)           (None, 160)          0           uhat_digitcaps[0][0]             
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 160)          0           multiply_3[0][0]                 
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            161         leaky_re_lu_4[0][0]              
==================================================================================================
Total params: 6,882,881
Trainable params: 6,882,353
Non-trainable params: 528
__________________________________________________________________________________________________

Load the mnist dataset

In [7]:
from keras.datasets import mnist

def load_data():
    
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    x_test = (x_test.astype(np.float32) - 127.5)/127.5
    
    x_train = x_train.reshape(-1,28,28,1)
    x_test = x_test.reshape(-1,28,28,1)
    
    y_train = to_categorical(y_train.astype('float32'))
    y_test = to_categorical(y_test.astype('float32'))
    
    return (x_train, y_train, x_test, y_test)


(x_train, y_train,x_test, y_test)=load_data()

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
(60000, 28, 28, 1) (60000, 10) (10000, 28, 28, 1) (10000, 10)

Building the Generator model

The generator's architecture is the same as of the architecture of a regular Deep Convolutional GAN - DCGAN

In [0]:
def build_generator():
    
    noise_shape = (100,)
    x_noise = Input(shape=noise_shape)
    
    x = Dense(128*7*7, activation="relu")(x_noise)
    x = Reshape((7, 7, 128))(x)
    x = UpSampling2D()(x)
    x = Conv2D(128, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)
    x = UpSampling2D()(x)
    x = Conv2D(64, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)
    x = Conv2D(1, kernel_size=3, padding="same")(x)
    x_out = Activation('tanh')(x)
    
    model = models.Model(x_noise, x_out)
    model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    
    
    return model
In [9]:
generator = build_generator()
generator.summary()
W0805 17:03:49.147365 140105437140864 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:2018: The name tf.image.resize_nearest_neighbor is deprecated. Please use tf.compat.v1.image.resize_nearest_neighbor instead.

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 100)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 128)       147584    
_________________________________________________________________
batch_normalization_3 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        73792     
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 64)        256       
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 1)         577       
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 1)         0         
=================================================================
Total params: 856,193
Trainable params: 855,809
Non-trainable params: 384
_________________________________________________________________

Function to plot the generated images

In [0]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(100,28,28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image %d.png' %epoch)

Defining the Adversarial model

In [0]:
def create_gan(g, d):
    
    z = Input(shape=[100])
    img = g(z)
    v = d(g(z))
    
    gan = models.Model(z, v)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    
    return gan
In [12]:
gan = create_gan(generator, discriminator)
gan.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         (None, 100)               0         
_________________________________________________________________
model_2 (Model)              (None, 28, 28, 1)         856193    
_________________________________________________________________
model_1 (Model)              (None, 1)                 6882881   
=================================================================
Total params: 7,739,074
Trainable params: 7,738,162
Non-trainable params: 912
_________________________________________________________________

Training the model

In [0]:
def train(data, epochs, batch_size=128):
    
    #unpack data
    x_train = data
    
    K.clear_session()
    
    gen = build_generator()
    disc = build_capsNet()
    gan = create_gan(gen, disc)
    
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    
    for e in range(epochs+1):
        
        real_indxs = np.random.randint(0, x_train.shape[0], batch_size)
        real_imgs = x_train[real_indxs]
        
        noise = np.random.normal(0,1, [batch_size, 100])
        gen_imgs = gen.predict(noise)
        
        x = np.concatenate([real_imgs, gen_imgs])
        y = np.concatenate([valid, fake])
        
        #Training the Discriminator
        disc.Trainable = True
        
        d_loss_real = disc.train_on_batch(real_imgs, valid*0.9) # 0.9 for label smoothing
        d_loss_fake = disc.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        #Training the generator 
        disc.Trainable = False
        
        noise = np.random.normal(0,1, [batch_size, 100])
        y_gan = np.ones((batch_size, 1))
        
        g_loss = gan.train_on_batch((noise), [y_gan])
        
        if e%100 == 0:
          
          #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (e, d_loss, 100*d_loss, g_loss))
          print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (e, d_loss[0], 100*d_loss[1], g_loss))
        
        if e % 1000 == 0:
            plot_generated_images(e, gen)
In [14]:
train(x_train, epochs=30000)
0 [D loss: 0.692452, acc.: 25.78%] [G loss: 0.703323]
100 [D loss: 0.690022, acc.: 0.00%] [G loss: 0.543523]
200 [D loss: 0.653430, acc.: 0.00%] [G loss: 0.559143]
300 [D loss: 0.663862, acc.: 0.39%] [G loss: 0.576408]
400 [D loss: 0.678835, acc.: 0.00%] [G loss: 0.527379]
500 [D loss: 0.757925, acc.: 0.00%] [G loss: 0.502246]
600 [D loss: 0.754791, acc.: 0.00%] [G loss: 0.482153]
700 [D loss: 0.771630, acc.: 0.78%] [G loss: 0.478499]
800 [D loss: 0.775678, acc.: 1.17%] [G loss: 0.474733]
900 [D loss: 0.806840, acc.: 0.00%] [G loss: 0.446639]
1000 [D loss: 0.787823, acc.: 0.39%] [G loss: 0.439376]
1100 [D loss: 0.789819, acc.: 0.00%] [G loss: 0.431329]
1200 [D loss: 0.785772, acc.: 0.00%] [G loss: 0.429723]
1300 [D loss: 0.818520, acc.: 0.00%] [G loss: 0.450631]
1400 [D loss: 0.793129, acc.: 0.00%] [G loss: 0.437919]
1500 [D loss: 0.838765, acc.: 0.39%] [G loss: 0.416567]
1600 [D loss: 0.813028, acc.: 0.78%] [G loss: 0.427587]
1700 [D loss: 0.781254, acc.: 0.00%] [G loss: 0.438877]
1800 [D loss: 0.806718, acc.: 0.00%] [G loss: 0.405354]
1900 [D loss: 0.813009, acc.: 0.00%] [G loss: 0.470142]
2000 [D loss: 0.807765, acc.: 0.78%] [G loss: 0.441852]
2100 [D loss: 0.797949, acc.: 0.78%] [G loss: 0.437708]
2200 [D loss: 0.835247, acc.: 0.00%] [G loss: 0.457303]
2300 [D loss: 0.822961, acc.: 0.78%] [G loss: 0.455397]
2400 [D loss: 0.801394, acc.: 0.00%] [G loss: 0.443033]
2500 [D loss: 0.812078, acc.: 1.17%] [G loss: 0.460489]
2600 [D loss: 0.801116, acc.: 0.78%] [G loss: 0.477230]
2700 [D loss: 0.810231, acc.: 0.78%] [G loss: 0.483210]
2800 [D loss: 0.785770, acc.: 0.00%] [G loss: 0.473317]
2900 [D loss: 0.807290, acc.: 0.39%] [G loss: 0.460093]
3000 [D loss: 0.772272, acc.: 1.56%] [G loss: 0.478448]
3100 [D loss: 0.782626, acc.: 0.39%] [G loss: 0.455067]
3200 [D loss: 0.786437, acc.: 0.00%] [G loss: 0.477131]
3300 [D loss: 0.766993, acc.: 1.17%] [G loss: 0.474221]
3400 [D loss: 0.792217, acc.: 1.56%] [G loss: 0.485181]
3500 [D loss: 0.785192, acc.: 1.17%] [G loss: 0.443678]
3600 [D loss: 0.764867, acc.: 1.56%] [G loss: 0.482338]
3700 [D loss: 0.781310, acc.: 1.17%] [G loss: 0.478673]
3800 [D loss: 0.769221, acc.: 1.56%] [G loss: 0.483557]
3900 [D loss: 0.763559, acc.: 1.56%] [G loss: 0.483390]
4000 [D loss: 0.775536, acc.: 1.17%] [G loss: 0.446959]
4100 [D loss: 0.752658, acc.: 1.95%] [G loss: 0.510528]
4200 [D loss: 0.762216, acc.: 0.39%] [G loss: 0.526250]
4300 [D loss: 0.766740, acc.: 1.56%] [G loss: 0.525008]
4400 [D loss: 0.760838, acc.: 2.34%] [G loss: 0.519561]
4500 [D loss: 0.771830, acc.: 0.78%] [G loss: 0.510173]
4600 [D loss: 0.775047, acc.: 1.17%] [G loss: 0.505063]
4700 [D loss: 0.729962, acc.: 5.08%] [G loss: 0.519618]
4800 [D loss: 0.747539, acc.: 4.30%] [G loss: 0.516281]
4900 [D loss: 0.756686, acc.: 2.34%] [G loss: 0.517949]
5000 [D loss: 0.768075, acc.: 5.47%] [G loss: 0.528019]
5100 [D loss: 0.752126, acc.: 2.34%] [G loss: 0.549251]
5200 [D loss: 0.780343, acc.: 3.12%] [G loss: 0.530229]
5300 [D loss: 0.763236, acc.: 1.56%] [G loss: 0.519304]
5400 [D loss: 0.802363, acc.: 2.73%] [G loss: 0.515954]
5500 [D loss: 0.741653, acc.: 3.52%] [G loss: 0.541419]
5600 [D loss: 0.792818, acc.: 1.56%] [G loss: 0.507718]
5700 [D loss: 0.756437, acc.: 5.08%] [G loss: 0.548460]
5800 [D loss: 0.800576, acc.: 2.73%] [G loss: 0.531722]
5900 [D loss: 0.781686, acc.: 3.91%] [G loss: 0.500522]
6000 [D loss: 0.779418, acc.: 2.73%] [G loss: 0.554620]
6100 [D loss: 0.779314, acc.: 3.52%] [G loss: 0.508776]
6200 [D loss: 0.772333, acc.: 2.73%] [G loss: 0.539463]
6300 [D loss: 0.794760, acc.: 4.30%] [G loss: 0.546815]
6400 [D loss: 0.789321, acc.: 3.12%] [G loss: 0.511675]
6500 [D loss: 0.812679, acc.: 2.73%] [G loss: 0.514616]
6600 [D loss: 0.782484, acc.: 2.34%] [G loss: 0.523521]
6700 [D loss: 0.769942, acc.: 4.69%] [G loss: 0.531607]
6800 [D loss: 0.771663, acc.: 3.12%] [G loss: 0.608783]
6900 [D loss: 0.810353, acc.: 2.34%] [G loss: 0.523981]
7000 [D loss: 0.787282, acc.: 5.08%] [G loss: 0.559365]
7100 [D loss: 0.796203, acc.: 1.56%] [G loss: 0.504687]
7200 [D loss: 0.780915, acc.: 3.52%] [G loss: 0.530293]
7300 [D loss: 0.839863, acc.: 2.34%] [G loss: 0.567453]
7400 [D loss: 0.802081, acc.: 1.95%] [G loss: 0.547280]
7500 [D loss: 0.789414, acc.: 3.12%] [G loss: 0.608225]
7600 [D loss: 0.771016, acc.: 3.12%] [G loss: 0.568828]
7700 [D loss: 0.806163, acc.: 4.30%] [G loss: 0.546677]
7800 [D loss: 0.795043, acc.: 4.30%] [G loss: 0.551776]
7900 [D loss: 0.778442, acc.: 6.25%] [G loss: 0.562973]
8000 [D loss: 0.776459, acc.: 4.30%] [G loss: 0.584549]
8100 [D loss: 0.793599, acc.: 2.73%] [G loss: 0.533402]
8200 [D loss: 0.812054, acc.: 1.95%] [G loss: 0.546968]
8300 [D loss: 0.772636, acc.: 4.69%] [G loss: 0.571654]
8400 [D loss: 0.794425, acc.: 5.08%] [G loss: 0.579289]
8500 [D loss: 0.790836, acc.: 2.34%] [G loss: 0.531856]
8600 [D loss: 0.782967, acc.: 2.34%] [G loss: 0.526793]
8700 [D loss: 0.772555, acc.: 7.03%] [G loss: 0.618529]
8800 [D loss: 0.819971, acc.: 4.30%] [G loss: 0.547621]
8900 [D loss: 0.792278, acc.: 5.47%] [G loss: 0.627251]
9000 [D loss: 0.802523, acc.: 3.12%] [G loss: 0.552374]
9100 [D loss: 0.780059, acc.: 5.47%] [G loss: 0.587454]
9200 [D loss: 0.791363, acc.: 2.73%] [G loss: 0.601027]
9300 [D loss: 0.786644, acc.: 4.69%] [G loss: 0.576391]
9400 [D loss: 0.755548, acc.: 4.69%] [G loss: 0.543309]
9500 [D loss: 0.775507, acc.: 5.47%] [G loss: 0.611354]
9600 [D loss: 0.779620, acc.: 5.47%] [G loss: 0.602152]
9700 [D loss: 0.804500, acc.: 3.52%] [G loss: 0.605936]
9800 [D loss: 0.769888, acc.: 9.77%] [G loss: 0.582661]
9900 [D loss: 0.758395, acc.: 2.34%] [G loss: 0.610786]
10000 [D loss: 0.797786, acc.: 5.86%] [G loss: 0.614999]
10100 [D loss: 0.796570, acc.: 2.73%] [G loss: 0.595362]
10200 [D loss: 0.779153, acc.: 6.25%] [G loss: 0.646758]
10300 [D loss: 0.783715, acc.: 2.73%] [G loss: 0.581077]
10400 [D loss: 0.756933, acc.: 5.08%] [G loss: 0.563669]
10500 [D loss: 0.761696, acc.: 6.25%] [G loss: 0.641574]
10600 [D loss: 0.752802, acc.: 6.64%] [G loss: 0.615941]
10700 [D loss: 0.736722, acc.: 6.25%] [G loss: 0.625253]
10800 [D loss: 0.791924, acc.: 5.08%] [G loss: 0.700703]
10900 [D loss: 0.779669, acc.: 5.86%] [G loss: 0.689557]
11000 [D loss: 0.756509, acc.: 6.25%] [G loss: 0.662029]
11100 [D loss: 0.762871, acc.: 7.03%] [G loss: 0.753637]
11200 [D loss: 0.751851, acc.: 2.34%] [G loss: 0.677252]
11300 [D loss: 0.791600, acc.: 6.64%] [G loss: 0.795058]
11400 [D loss: 0.738482, acc.: 7.03%] [G loss: 0.658596]
11500 [D loss: 0.731306, acc.: 4.30%] [G loss: 0.681449]
11600 [D loss: 0.815618, acc.: 3.52%] [G loss: 0.754749]
11700 [D loss: 0.764675, acc.: 2.34%] [G loss: 0.675612]
11800 [D loss: 0.777934, acc.: 5.08%] [G loss: 0.715532]
11900 [D loss: 0.752043, acc.: 5.47%] [G loss: 0.794173]
12000 [D loss: 0.773118, acc.: 1.95%] [G loss: 0.668240]
12100 [D loss: 0.786013, acc.: 4.30%] [G loss: 0.705114]
12200 [D loss: 0.747987, acc.: 5.08%] [G loss: 0.709269]
12300 [D loss: 0.797894, acc.: 3.12%] [G loss: 0.713377]
12400 [D loss: 0.788102, acc.: 4.69%] [G loss: 0.770477]
12500 [D loss: 0.744619, acc.: 5.47%] [G loss: 0.726461]
12600 [D loss: 0.761752, acc.: 5.08%] [G loss: 0.721695]
12700 [D loss: 0.794136, acc.: 4.30%] [G loss: 0.715248]
12800 [D loss: 0.766701, acc.: 3.12%] [G loss: 0.745745]
12900 [D loss: 0.726490, acc.: 8.20%] [G loss: 0.721939]
13000 [D loss: 0.781962, acc.: 6.25%] [G loss: 0.675459]
13100 [D loss: 0.783298, acc.: 3.52%] [G loss: 0.707204]
13200 [D loss: 0.759145, acc.: 6.64%] [G loss: 0.739165]
13300 [D loss: 0.749694, acc.: 2.73%] [G loss: 0.627443]
13400 [D loss: 0.767790, acc.: 3.12%] [G loss: 0.718496]
13500 [D loss: 0.741907, acc.: 5.08%] [G loss: 0.668030]
13600 [D loss: 0.745389, acc.: 4.30%] [G loss: 0.704925]
13700 [D loss: 0.790089, acc.: 4.30%] [G loss: 0.689641]
13800 [D loss: 0.754079, acc.: 5.47%] [G loss: 0.769687]
13900 [D loss: 0.757266, acc.: 5.47%] [G loss: 0.797655]
14000 [D loss: 0.774773, acc.: 3.91%] [G loss: 0.697047]
14100 [D loss: 0.753349, acc.: 5.08%] [G loss: 0.715350]
14200 [D loss: 0.743380, acc.: 4.30%] [G loss: 0.720229]
14300 [D loss: 0.777304, acc.: 5.08%] [G loss: 0.695988]
14400 [D loss: 0.778791, acc.: 2.73%] [G loss: 0.741093]
14500 [D loss: 0.749682, acc.: 5.47%] [G loss: 0.711550]
14600 [D loss: 0.754771, acc.: 3.12%] [G loss: 0.752398]
14700 [D loss: 0.741628, acc.: 5.86%] [G loss: 0.754923]
14800 [D loss: 0.724669, acc.: 4.69%] [G loss: 0.743072]
14900 [D loss: 0.750736, acc.: 6.25%] [G loss: 0.722358]
15000 [D loss: 0.749732, acc.: 3.12%] [G loss: 0.733704]
15100 [D loss: 0.743481, acc.: 5.86%] [G loss: 0.766612]
15200 [D loss: 0.763330, acc.: 4.69%] [G loss: 0.733050]
15300 [D loss: 0.763336, acc.: 6.25%] [G loss: 0.723359]
15400 [D loss: 0.765880, acc.: 2.73%] [G loss: 0.735357]
15500 [D loss: 0.778841, acc.: 4.69%] [G loss: 0.735866]
15600 [D loss: 0.778821, acc.: 3.91%] [G loss: 0.783075]
15700 [D loss: 0.731905, acc.: 7.81%] [G loss: 0.784513]
15800 [D loss: 0.777322, acc.: 4.30%] [G loss: 0.713081]
15900 [D loss: 0.758019, acc.: 2.73%] [G loss: 0.735633]
16000 [D loss: 0.796840, acc.: 4.30%] [G loss: 0.713938]
16100 [D loss: 0.745458, acc.: 4.30%] [G loss: 0.725449]
16200 [D loss: 0.752116, acc.: 2.73%] [G loss: 0.718048]
16300 [D loss: 0.765403, acc.: 4.30%] [G loss: 0.699426]
16400 [D loss: 0.748301, acc.: 5.08%] [G loss: 0.800062]
16500 [D loss: 0.775233, acc.: 5.47%] [G loss: 0.743834]
16600 [D loss: 0.751604, acc.: 5.47%] [G loss: 0.847628]
16700 [D loss: 0.829215, acc.: 2.73%] [G loss: 0.764475]
16800 [D loss: 0.767427, acc.: 2.73%] [G loss: 0.748264]
16900 [D loss: 0.793147, acc.: 4.30%] [G loss: 0.770962]
17000 [D loss: 0.757554, acc.: 5.47%] [G loss: 0.718267]
17100 [D loss: 0.808694, acc.: 7.03%] [G loss: 0.765667]
17200 [D loss: 0.780034, acc.: 7.81%] [G loss: 0.815729]
17300 [D loss: 0.738654, acc.: 5.86%] [G loss: 0.864958]
17400 [D loss: 0.740067, acc.: 6.25%] [G loss: 0.827473]
17500 [D loss: 0.739848, acc.: 4.30%] [G loss: 0.830323]
17600 [D loss: 0.765648, acc.: 4.30%] [G loss: 0.953386]
17700 [D loss: 0.733113, acc.: 12.89%] [G loss: 0.812113]
17800 [D loss: 0.754455, acc.: 4.69%] [G loss: 0.826568]
17900 [D loss: 0.765897, acc.: 6.25%] [G loss: 0.892829]
18000 [D loss: 0.752906, acc.: 7.03%] [G loss: 0.838185]
18100 [D loss: 0.756647, acc.: 5.86%] [G loss: 0.852829]
18200 [D loss: 0.732399, acc.: 5.08%] [G loss: 0.827914]
18300 [D loss: 0.801475, acc.: 5.86%] [G loss: 0.903144]
18400 [D loss: 0.753599, acc.: 2.73%] [G loss: 0.930157]
18500 [D loss: 0.738733, acc.: 6.25%] [G loss: 0.923680]
18600 [D loss: 0.789007, acc.: 2.34%] [G loss: 0.896178]
18700 [D loss: 0.764591, acc.: 3.52%] [G loss: 0.962082]
18800 [D loss: 0.781094, acc.: 5.47%] [G loss: 0.889822]
18900 [D loss: 0.782912, acc.: 1.56%] [G loss: 0.894223]
19000 [D loss: 0.745277, acc.: 5.47%] [G loss: 0.983501]
19100 [D loss: 0.768936, acc.: 3.91%] [G loss: 0.926200]
19200 [D loss: 0.742945, acc.: 4.69%] [G loss: 0.913800]
19300 [D loss: 0.780743, acc.: 3.12%] [G loss: 0.956809]
19400 [D loss: 0.759549, acc.: 3.91%] [G loss: 1.000206]
19500 [D loss: 0.719528, acc.: 5.86%] [G loss: 0.921185]
19600 [D loss: 0.798346, acc.: 1.17%] [G loss: 0.988111]
19700 [D loss: 0.772479, acc.: 6.64%] [G loss: 0.942184]
19800 [D loss: 0.764895, acc.: 3.91%] [G loss: 1.029450]
19900 [D loss: 0.741695, acc.: 4.30%] [G loss: 1.069826]
20000 [D loss: 0.781008, acc.: 3.91%] [G loss: 1.054244]
/usr/local/lib/python3.6/dist-packages/matplotlib/pyplot.py:514: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)
20100 [D loss: 0.759187, acc.: 6.25%] [G loss: 1.003588]
20200 [D loss: 0.755007, acc.: 6.64%] [G loss: 1.014826]
20300 [D loss: 0.745500, acc.: 6.64%] [G loss: 1.040226]
20400 [D loss: 0.766783, acc.: 6.25%] [G loss: 0.973588]
20500 [D loss: 0.775017, acc.: 3.12%] [G loss: 1.040720]
20600 [D loss: 0.730953, acc.: 5.86%] [G loss: 1.023584]
20700 [D loss: 0.722273, acc.: 5.86%] [G loss: 1.077883]
20800 [D loss: 0.759717, acc.: 6.64%] [G loss: 0.992814]
20900 [D loss: 0.762302, acc.: 6.64%] [G loss: 1.089961]
21000 [D loss: 0.676326, acc.: 14.45%] [G loss: 1.151305]
21100 [D loss: 0.722197, acc.: 9.38%] [G loss: 1.161070]
21200 [D loss: 0.839579, acc.: 7.42%] [G loss: 1.716116]
21300 [D loss: 0.798894, acc.: 16.80%] [G loss: 1.737913]
21400 [D loss: 0.712865, acc.: 21.09%] [G loss: 2.021316]
21500 [D loss: 0.675401, acc.: 18.75%] [G loss: 2.071393]
21600 [D loss: 0.511029, acc.: 30.86%] [G loss: 1.846537]
21700 [D loss: 0.447515, acc.: 39.45%] [G loss: 1.776464]
21800 [D loss: 0.500594, acc.: 30.86%] [G loss: 1.661152]
21900 [D loss: 0.465482, acc.: 33.98%] [G loss: 1.782076]
22000 [D loss: 0.469424, acc.: 35.94%] [G loss: 1.814590]
22100 [D loss: 0.456409, acc.: 38.67%] [G loss: 1.654762]
22200 [D loss: 0.433560, acc.: 39.84%] [G loss: 1.750577]
22300 [D loss: 0.449079, acc.: 35.94%] [G loss: 2.033835]
22400 [D loss: 0.443034, acc.: 38.28%] [G loss: 2.006118]
22500 [D loss: 0.512545, acc.: 30.47%] [G loss: 2.153386]
22600 [D loss: 0.528864, acc.: 29.69%] [G loss: 2.397640]
22700 [D loss: 0.486909, acc.: 33.98%] [G loss: 2.143789]
22800 [D loss: 0.544037, acc.: 24.22%] [G loss: 2.257730]
22900 [D loss: 0.546596, acc.: 28.12%] [G loss: 2.208504]
23000 [D loss: 0.565019, acc.: 25.39%] [G loss: 2.154015]
23100 [D loss: 0.604460, acc.: 16.41%] [G loss: 2.210700]
23200 [D loss: 0.592381, acc.: 20.31%] [G loss: 1.868054]
23300 [D loss: 0.645523, acc.: 16.41%] [G loss: 1.962771]
23400 [D loss: 0.649154, acc.: 14.06%] [G loss: 1.825845]
23500 [D loss: 0.622881, acc.: 15.62%] [G loss: 1.741120]
23600 [D loss: 0.671177, acc.: 12.89%] [G loss: 1.848610]
23700 [D loss: 0.669731, acc.: 13.28%] [G loss: 1.830541]
23800 [D loss: 0.674457, acc.: 11.33%] [G loss: 1.721230]
23900 [D loss: 0.671687, acc.: 10.94%] [G loss: 1.607483]
24000 [D loss: 0.675991, acc.: 14.06%] [G loss: 1.642901]
24100 [D loss: 0.752855, acc.: 7.42%] [G loss: 1.649858]
24200 [D loss: 0.660844, acc.: 12.50%] [G loss: 1.753813]
24300 [D loss: 0.704641, acc.: 5.47%] [G loss: 1.681568]
24400 [D loss: 0.708953, acc.: 6.64%] [G loss: 1.829480]
24500 [D loss: 0.677484, acc.: 7.42%] [G loss: 1.958269]
24600 [D loss: 0.711041, acc.: 9.38%] [G loss: 1.846528]
24700 [D loss: 0.694269, acc.: 7.03%] [G loss: 1.930164]
24800 [D loss: 0.682238, acc.: 10.16%] [G loss: 1.871030]
24900 [D loss: 0.754470, acc.: 5.86%] [G loss: 1.891112]
25000 [D loss: 0.672288, acc.: 13.67%] [G loss: 1.923395]
25100 [D loss: 0.684624, acc.: 10.16%] [G loss: 1.823849]
25200 [D loss: 0.724713, acc.: 5.47%] [G loss: 1.716901]
25300 [D loss: 0.690481, acc.: 7.81%] [G loss: 1.808361]
25400 [D loss: 0.706057, acc.: 9.38%] [G loss: 1.726350]
25500 [D loss: 0.694526, acc.: 10.94%] [G loss: 1.907385]
25600 [D loss: 0.675557, acc.: 9.77%] [G loss: 1.695900]
25700 [D loss: 0.727217, acc.: 5.86%] [G loss: 1.892532]
25800 [D loss: 0.692986, acc.: 7.42%] [G loss: 1.790794]
25900 [D loss: 0.697279, acc.: 10.16%] [G loss: 1.794724]
26000 [D loss: 0.644745, acc.: 11.33%] [G loss: 1.847401]
26100 [D loss: 0.682499, acc.: 9.77%] [G loss: 1.799306]
26200 [D loss: 0.680945, acc.: 9.38%] [G loss: 1.982756]
26300 [D loss: 0.705917, acc.: 5.86%] [G loss: 2.004824]
26400 [D loss: 0.708948, acc.: 6.64%] [G loss: 1.907732]
26500 [D loss: 0.708379, acc.: 7.81%] [G loss: 1.889468]
26600 [D loss: 0.705493, acc.: 7.42%] [G loss: 1.804449]
26700 [D loss: 0.716154, acc.: 7.42%] [G loss: 1.671242]
26800 [D loss: 0.740863, acc.: 6.25%] [G loss: 1.919907]
26900 [D loss: 0.759571, acc.: 5.47%] [G loss: 1.945306]
27000 [D loss: 0.659360, acc.: 15.62%] [G loss: 1.774236]
27100 [D loss: 0.678483, acc.: 8.98%] [G loss: 1.777969]
27200 [D loss: 0.717601, acc.: 7.03%] [G loss: 1.987188]
27300 [D loss: 0.740728, acc.: 7.03%] [G loss: 2.080215]
27400 [D loss: 0.687968, acc.: 8.98%] [G loss: 2.188800]
27500 [D loss: 0.732970, acc.: 5.08%] [G loss: 2.133766]
27600 [D loss: 0.754524, acc.: 6.64%] [G loss: 2.262442]
27700 [D loss: 0.670401, acc.: 11.72%] [G loss: 2.146183]
27800 [D loss: 0.732749, acc.: 7.42%] [G loss: 2.380857]
27900 [D loss: 0.715597, acc.: 9.38%] [G loss: 2.427826]
28000 [D loss: 0.724332, acc.: 8.20%] [G loss: 2.388813]
28100 [D loss: 0.725853, acc.: 12.50%] [G loss: 2.518575]
28200 [D loss: 0.777362, acc.: 4.69%] [G loss: 2.714078]
28300 [D loss: 0.686371, acc.: 13.67%] [G loss: 2.693961]
28400 [D loss: 0.632561, acc.: 19.14%] [G loss: 2.539610]
28500 [D loss: 0.663920, acc.: 19.14%] [G loss: 2.858774]
28600 [D loss: 0.645854, acc.: 17.19%] [G loss: 2.862521]
28700 [D loss: 0.666258, acc.: 16.41%] [G loss: 2.833820]
28800 [D loss: 0.777607, acc.: 4.69%] [G loss: 3.445018]
28900 [D loss: 0.688071, acc.: 15.62%] [G loss: 3.108246]
29000 [D loss: 0.702057, acc.: 10.55%] [G loss: 3.623674]
29100 [D loss: 0.655445, acc.: 18.75%] [G loss: 3.830509]
29200 [D loss: 0.719887, acc.: 13.28%] [G loss: 3.848180]
29300 [D loss: 0.696896, acc.: 16.80%] [G loss: 3.823534]
29400 [D loss: 0.648782, acc.: 16.02%] [G loss: 3.948747]
29500 [D loss: 0.692694, acc.: 10.55%] [G loss: 4.274864]
29600 [D loss: 0.755625, acc.: 10.94%] [G loss: 4.435627]
29700 [D loss: 0.592548, acc.: 24.22%] [G loss: 3.961800]
29800 [D loss: 0.720324, acc.: 8.98%] [G loss: 4.523563]
29900 [D loss: 0.545070, acc.: 32.42%] [G loss: 4.114429]
30000 [D loss: 0.724609, acc.: 12.50%] [G loss: 4.778055]